#include "Conv.h"
#include "ap_int.h"



/* Rectified linear unit: strictly positive inputs pass through unchanged,
 * every other input (including negatives and NaN) yields zero. */
float relu(float x){
////#pragma HLS inline
	if (x > 0) {
		return x;
	}
	return 0;
}

/*
 * First convolution layer: a per-channel FILTER_HT x FILTER_WH spatial filter
 * followed by a 1x1 mix across input channels, plus bias and ReLU.
 *
 * input_feature  - read as [batch][row][col][channel] (channel varies fastest).
 * weights        - two consecutive tables:
 *                    1. spatial taps, indexed [row_k][col_k][channel];
 *                    2. starting at INPUT_CHANNEL*FILTER_SIZE, the channel-mix
 *                       coefficients, indexed [in_channel][out_channel].
 * bias           - one value per output channel (CONV_1_TYPE entries).
 * output_feature - written as [batch][out_channel][position], where
 *                  position = row*CONV_1_OUTPUT_WH + col.
 * init           - non-zero: (re)load weights and bias into the local caches.
 *                  zero: reuse the values loaded by the most recent call with
 *                  init != 0 (all zeros if no such call happened yet).
 *
 * NOTE(review): assumes CONV_1_OUTPUT_HT + FILTER_HT - 1 <= CONV_1_INPUT_HT
 * (and likewise for width) and CONV_1_OUTPUT_SIZE == CONV_1_OUTPUT_HT *
 * CONV_1_OUTPUT_WH - confirm against Conv.h.
 */
void CONVOLUTION_LAYER_1(const float input_feature[image_Batch*INPUT_CHANNEL*CONV_1_INPUT_SIZE],
        const float weights[INPUT_CHANNEL*FILTER_SIZE+INPUT_CHANNEL*CONV_1_TYPE],
        const float bias[CONV_1_TYPE],
        float output_feature[image_Batch*CONV_1_TYPE*CONV_1_OUTPUT_SIZE],
        int init
        )
{
    /* Per-call working buffers. */
    float IBRAM[image_Batch][INPUT_CHANNEL][CONV_1_INPUT_HT][CONV_1_INPUT_WH];
    float TEMBRAM[image_Batch][INPUT_CHANNEL][CONV_1_OUTPUT_SIZE];
    float OBRAM[image_Batch][CONV_1_TYPE][CONV_1_OUTPUT_SIZE];

    /* Parameter caches. They are static so that a call with init == 0 sees the
     * values stored by an earlier init call instead of indeterminate stack
     * contents. Dimensions follow the same macros the copy loops iterate over. */
    static float W1BRAM[INPUT_CHANNEL][FILTER_HT][FILTER_WH];
    static float W2BRAM[CONV_1_TYPE][INPUT_CHANNEL];
    static float biasBRAM[CONV_1_TYPE];

////#pragma HLS array_partition variable=W1BRAM complete dim=1
////#pragma HLS array_partition variable=W2BRAM complete dim=1
////#pragma HLS array_partition variable=biasBRAM complete dim=0
////#pragma HLS array_partition variable=TEMBRAM complete dim=2
////#pragma HLS array_partition variable=OBRAM complete dim=2

    /* Stage the input locally, reordering channel-last to channel-first. */
    copy_input_1:
    for(int batch=0; batch<image_Batch; batch++){
        const int batch_offset = batch*INPUT_CHANNEL*CONV_1_INPUT_SIZE;
        copy_input_2:
        for(int j=0; j<INPUT_CHANNEL; j++){
            copy_input_3:
            for(int k=0; k<CONV_1_INPUT_HT; k++){
                copy_input_4:
                for(int l=0; l<CONV_1_INPUT_WH; l++){
////#pragma HLS PIPELINE II=1
                    IBRAM[batch][j][k][l] = input_feature[batch_offset
                                                        +k*CONV_1_INPUT_WH*INPUT_CHANNEL
                                                        +l*INPUT_CHANNEL
                                                        +j];
                }
            }
        }
    }

    if(init){
        /* Spatial taps: source order is [row_k][col_k][channel]. */
        copy_kernel_1:
        for(int i=0; i<INPUT_CHANNEL; i++){
            copy_kernel_2:
            for(int j=0; j<FILTER_HT; j++){
                for(int k=0; k<FILTER_WH; k++){
////#pragma HLS PIPELINE II=1
                    W1BRAM[i][j][k] = weights[j*FILTER_WH*INPUT_CHANNEL
                                            +k*INPUT_CHANNEL
                                            +i];
                }
            }
        }

        /* Channel-mix coefficients follow the spatial taps; source order is
         * [in_channel][out_channel], stored transposed for the mix stage. */
        const int offset = INPUT_CHANNEL*FILTER_SIZE;
        copy_kernel_3:
        for(int i=0; i<INPUT_CHANNEL; i++){
            copy_kernel_4:
            for(int j=0; j<CONV_1_TYPE; j++){
////#pragma HLS PIPELINE II=1
                W2BRAM[j][i] = weights[offset
                                        +i*CONV_1_TYPE
                                        +j];
            }
        }

        copy_bias:
        for(int i=0; i<CONV_1_TYPE; i++){
////#pragma HLS PIPELINE II=1
            biasBRAM[i] = bias[i];
        }
    }

    /*---------------conv1---------------*/
    BATCH:
    for(int batch_cnt=0; batch_cnt<image_Batch; batch_cnt++){
        /* Spatial stage: one filter tap at a time is accumulated over the whole
         * output plane; the first tap overwrites so TEMBRAM needs no clearing. */
        ROW_K:
        for(int row_k=0; row_k<FILTER_HT; row_k++){
            COL_K:
            for(int col_k=0; col_k<FILTER_WH; col_k++){
                const int first_tap = (row_k==0 && col_k==0);
                ROW:
                for(int row=0; row<CONV_1_OUTPUT_HT; row++){
                    COL:
                    for(int col=0; col<CONV_1_OUTPUT_WH; col++){
////#pragma HLS PIPELINE II=1
                        float mult[INPUT_CHANNEL];
////#pragma HLS array_partition variable=mult complete dim=0
                        const int posi = row*CONV_1_OUTPUT_WH+col;
                        D_OUT:
                        for(int co=0; co<INPUT_CHANNEL; co++){
////#pragma HLS unroll
                            mult[co] = IBRAM[batch_cnt][co][row+row_k][col+col_k]*W1BRAM[co][row_k][col_k];
                            if(first_tap)
                                TEMBRAM[batch_cnt][co][posi] = mult[co];
                            else
                                TEMBRAM[batch_cnt][co][posi] += mult[co];
                        }
                    }
                }
            }
        }

        /* 1x1 stage: weighted sum over every input channel, then bias and ReLU.
         * Products are added in channel order and the bias last. */
        OUT_C:
        for(int c=0; c<CONV_1_TYPE; c++){
            OUT_P:
            for(int posi=0; posi<CONV_1_OUTPUT_SIZE; posi++){
////#pragma HLS PIPELINE II=1
                float acc = W2BRAM[c][0]*TEMBRAM[batch_cnt][0][posi];
                MIX_C:
                for(int ci=1; ci<INPUT_CHANNEL; ci++){
                    acc += W2BRAM[c][ci]*TEMBRAM[batch_cnt][ci][posi];
                }
                OBRAM[batch_cnt][c][posi] = relu(acc + biasBRAM[c]);
            }
        }
    }

    /*copy output*/
    copy_output:
    for(int i=0; i<image_Batch; i++){
        const int batch_offset = i*CONV_1_OUTPUT_SIZE*CONV_1_TYPE;
        for(int j=0; j<CONV_1_TYPE; j++){
            const int depth_offset = j*CONV_1_OUTPUT_SIZE;
            for(int k=0; k<CONV_1_OUTPUT_SIZE; k++){
////#pragma HLS PIPELINE II=1
                output_feature[batch_offset+depth_offset+k] = OBRAM[i][j][k];
            }
        }
    }
}


/*
 * Second convolution layer: a per-channel FILTER_HT x FILTER_WH spatial filter
 * followed by a 1x1 mix across the CONV_1_TYPE input channels, plus bias; ReLU
 * is applied while the result is copied out.
 *
 * input_feature  - read as [batch][channel][row][col]
 *                  (CONV_1_TYPE channels of CONV_1_OUTPUT_HT x CONV_1_OUTPUT_WH).
 * weights        - two consecutive tables:
 *                    1. spatial taps, indexed [row_k][col_k][channel];
 *                    2. starting at CONV_1_TYPE*FILTER_SIZE, the channel-mix
 *                       coefficients, indexed [in_channel][out_channel].
 * bias           - one value per output channel (CONV_2_TYPE entries).
 * output_feature - written as [batch][out_channel][position], where
 *                  position = row*CONV_2_OUTPUT_WH + col.
 * init           - non-zero: (re)load weights and bias into the local caches.
 *                  zero: reuse the values loaded by the most recent call with
 *                  init != 0 (all zeros if no such call happened yet).
 *
 * NOTE(review): assumes CONV_2_OUTPUT_HT + FILTER_HT - 1 <= CONV_1_OUTPUT_HT
 * (and likewise for width) and CONV_2_OUTPUT_SIZE == CONV_2_OUTPUT_HT *
 * CONV_2_OUTPUT_WH - confirm against Conv.h.
 */
void CONVOLUTION_LAYER_2(const float input_feature[image_Batch*CONV_1_TYPE*CONV_1_OUTPUT_SIZE],
        const float weights[CONV_1_TYPE*FILTER_SIZE+CONV_2_TYPE*CONV_1_TYPE],
        const float bias[CONV_2_TYPE],
        float output_feature[image_Batch*CONV_2_TYPE*CONV_2_OUTPUT_SIZE],
        int init
        )
{
    /* Per-call working buffers. */
    float IBRAM[image_Batch][CONV_1_TYPE][CONV_1_OUTPUT_HT][CONV_1_OUTPUT_WH];
    float TEMBRAM[image_Batch][CONV_1_TYPE][CONV_2_OUTPUT_SIZE];
    float OBRAM[image_Batch][CONV_2_TYPE][CONV_2_OUTPUT_SIZE];

    /* Parameter caches. They are static so that a call with init == 0 sees the
     * values stored by an earlier init call instead of indeterminate stack
     * contents. Dimensions follow the same macros the copy loops iterate over. */
    static float W1BRAM[CONV_1_TYPE][FILTER_HT][FILTER_WH];
    static float W2BRAM[CONV_2_TYPE][CONV_1_TYPE];
    static float biasBRAM[CONV_2_TYPE];

////#pragma HLS array_partition variable=W1BRAM complete dim=1
////#pragma HLS array_partition variable=W2BRAM complete dim=1
////#pragma HLS array_partition variable=biasBRAM complete dim=0
////#pragma HLS array_partition variable=TEMBRAM complete dim=2
////#pragma HLS array_partition variable=OBRAM complete dim=2

    /* Stage the input locally; source and local layouts are both channel-first. */
    copy_input_1:
    for(int batch=0; batch<image_Batch; batch++){
        const int batch_offset = batch*CONV_1_TYPE*CONV_1_OUTPUT_SIZE;
        copy_input_2:
        for(int j=0; j<CONV_1_TYPE; j++){
            copy_input_3:
            for(int k=0; k<CONV_1_OUTPUT_HT; k++){
                copy_input_4:
                for(int l=0; l<CONV_1_OUTPUT_WH; l++){
////#pragma HLS pipeline II=1
                    IBRAM[batch][j][k][l] = input_feature[batch_offset
                                                        +j*CONV_1_OUTPUT_SIZE
                                                        +k*CONV_1_OUTPUT_WH
                                                        +l];
                }
            }
        }
    }

    if(init){
        /* Spatial taps: source order is [row_k][col_k][channel]. */
        copy_kernel_1:
        for(int i=0; i<CONV_1_TYPE; i++){
            copy_kernel_2:
            for(int j=0; j<FILTER_HT; j++){
                for(int k=0; k<FILTER_WH; k++){
//#pragma HLS pipeline II=1
                    W1BRAM[i][j][k] = weights[j*FILTER_WH*CONV_1_TYPE
                                            +k*CONV_1_TYPE
                                            +i];
                }
            }
        }

        /* Channel-mix coefficients follow the spatial taps; source order is
         * [in_channel][out_channel], stored transposed for the mix stage. */
        const int offset = CONV_1_TYPE*FILTER_SIZE;
        copy_kernel_3:
        for(int i=0; i<CONV_1_TYPE; i++){
            copy_kernel_4:
            for(int j=0; j<CONV_2_TYPE; j++){
//#pragma HLS pipeline II=1
                W2BRAM[j][i] = weights[offset
                                        +i*CONV_2_TYPE
                                        +j];
            }
        }

        copy_bias:
        for(int i=0; i<CONV_2_TYPE; i++){
//#pragma HLS pipeline II=1
            biasBRAM[i] = bias[i];
        }
    }

    /*---------------conv2---------------*/
    BATCH:
    for(int batch_cnt=0; batch_cnt<image_Batch; batch_cnt++){
        /* Spatial stage: one filter tap at a time is accumulated over the whole
         * output plane; the first tap overwrites so TEMBRAM needs no clearing. */
        ROW_K:
        for(int row_k=0; row_k<FILTER_HT; row_k++){
            COL_K:
            for(int col_k=0; col_k<FILTER_WH; col_k++){
                const int first_tap = (row_k==0 && col_k==0);
                ROW:
                for(int row=0; row<CONV_2_OUTPUT_HT; row++){
                    COL:
                    for(int col=0; col<CONV_2_OUTPUT_WH; col++){
//#pragma HLS pipeline II=1
                        float mult[CONV_1_TYPE];
//#pragma HLS array_partition variable=mult complete dim=0
                        const int posi = row*CONV_2_OUTPUT_WH+col;
                        D_OUT:
                        for(int co=0; co<CONV_1_TYPE; co++){
//#pragma HLS unroll
                            mult[co] = IBRAM[batch_cnt][co][row+row_k][col+col_k]*W1BRAM[co][row_k][col_k];
                            if(first_tap)
                                TEMBRAM[batch_cnt][co][posi] = mult[co];
                            else
                                TEMBRAM[batch_cnt][co][posi] += mult[co];
                        }
                    }
                }
            }
        }

        /* 1x1 stage: the first channel's product is seeded with the bias and
         * the remaining channels are added in order. */
        OUT_C:
        for(int c=0; c<CONV_2_TYPE; c++){
            OUT_P:
            for(int posi=0; posi<CONV_2_OUTPUT_SIZE; posi++){
                float acc = W2BRAM[c][0]*TEMBRAM[batch_cnt][0][posi]+biasBRAM[c];
                ROW_2:
                for(int p=1; p<CONV_1_TYPE; p++){
//#pragma HLS pipeline II=1
                    acc += W2BRAM[c][p]*TEMBRAM[batch_cnt][p][posi];
                }
                OBRAM[batch_cnt][c][posi] = acc;
            }
        }
    }

    /*copy output*/
    /* The activation is applied here, on the way out of the local buffer. */
    copy_output:
    for(int i=0; i<image_Batch; i++){
        const int batch_offset = i*CONV_2_OUTPUT_SIZE*CONV_2_TYPE;
        for(int j=0; j<CONV_2_TYPE; j++){
            const int depth_offset = j*CONV_2_OUTPUT_SIZE;
            for(int k=0; k<CONV_2_OUTPUT_SIZE; k++){
//#pragma HLS pipeline II=1
                output_feature[batch_offset+depth_offset+k] = relu(OBRAM[i][j][k]);
            }
        }
    }

}
